// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package segments

import (
	"context"
	"fmt"
	"math"

	"github.com/samber/lo"
	"go.opentelemetry.io/otel"
	"go.uber.org/zap"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/util/reduce"
	"github.com/milvus-io/milvus/internal/util/segcore"
	typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/segcorepb"
	"github.com/milvus-io/milvus/pkg/v2/util/conc"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

// Compile-time assertions that both the internal and the segcore retrieve
// result messages satisfy typeutil.ResultWithID.
var (
	_ typeutil.ResultWithID = (*internalpb.RetrieveResults)(nil)
	_ typeutil.ResultWithID = (*segcorepb.RetrieveResults)(nil)
)

// ReduceSearchOnQueryNode is the query-node entry point for merging search
// results. Requests flagged as advanced (info.GetIsAdvance()) are handed to
// ReduceAdvancedSearchResults; every other request goes through
// ReduceSearchResults.
func ReduceSearchOnQueryNode(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) {
	if !info.GetIsAdvance() {
		return ReduceSearchResults(ctx, results, info)
	}
	return ReduceAdvancedSearchResults(ctx, results)
}

// ReduceSearchResults merges several partial search results into one.
//
// Entries that are nil or carry no sliced blob are discarded first. If exactly
// one entry survives it is returned unchanged. Otherwise the blobs are decoded,
// reduced with the reducer chosen by InitSearchReducer(info), and re-encoded.
// The merged result additionally carries the union of the channel MVCC
// timestamps, the OR of the IsTopkReduce / IsRecallEvaluation flags, and the
// aggregated request cost of all inputs.
//
// Note that info is mutated: when its metric type is empty it is filled from
// the first input whose metric type is visited while it is still empty.
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) {
	// Keep only the entries that actually carry a serialized payload.
	results = lo.Filter(results, func(r *internalpb.SearchResults, _ int) bool {
		return r != nil && r.GetSlicedBlob() != nil
	})

	// A single usable result needs no merging.
	if len(results) == 1 {
		log.Debug("Shortcut return ReduceSearchResults", zap.Any("result info", info))
		return results[0], nil
	}

	ctx, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResults")
	defer span.End()

	var (
		mvccByChannel        = make(map[string]uint64)
		topkReduce           bool
		recallEvaluation     bool
		totalRelatedDataSize int64
		costs                = make([]*internalpb.CostAggregation, 0, len(results))
	)
	for _, r := range results {
		for channel, ts := range r.GetChannelsMvcc() {
			mvccByChannel[channel] = ts
		}
		topkReduce = topkReduce || r.GetIsTopkReduce()
		recallEvaluation = recallEvaluation || r.GetIsRecallEvaluation()

		// The delegator no longer loads sealed segments when the streaming node
		// is enabled and yields no cost metrics without growing segments, so the
		// cost of every worker result is always collected
		// (EnableWorkerSQCostMetrics is deprecated).
		costs = append(costs, r.GetCostAggregation())
		totalRelatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize()

		// The merged SearchResults must not end up with an empty metric type,
		// even if the request did not specify one.
		if info.GetMetricType() == "" {
			info.SetMetricType(r.MetricType)
		}
	}
	logger := log.Ctx(ctx)

	decoded, err := DecodeSearchResults(ctx, results)
	if err != nil {
		logger.Warn("shard leader decode search results errors", zap.Error(err))
		return nil, err
	}
	logger.Debug("shard leader get valid search results", zap.Int("numbers", len(decoded)))

	for no, data := range decoded {
		logger.Debug("reduceSearchResultData",
			zap.Int("result No.", no),
			zap.Int64("nq", data.NumQueries),
			zap.Int64("topk", data.TopK),
			zap.Int("ids.len", typeutil.GetSizeOfIDs(data.Ids)),
			zap.Int("fieldsData.len", len(data.FieldsData)))
	}

	reducer := InitSearchReducer(info)
	reduced, err := reducer.ReduceSearchResultData(ctx, decoded, info)
	if err != nil {
		logger.Warn("shard leader reduce errors", zap.Error(err))
		return nil, err
	}

	merged, err := EncodeSearchResultData(ctx, reduced, info.GetNq(), info.GetTopK(), info.GetMetricType())
	if err != nil {
		logger.Warn("shard leader encode search result errors", zap.Error(err))
		return nil, err
	}

	merged.CostAggregation = mergeRequestCost(costs)
	if merged.CostAggregation == nil {
		merged.CostAggregation = &internalpb.CostAggregation{}
	}
	merged.CostAggregation.TotalRelatedDataSize = totalRelatedDataSize
	merged.ChannelsMvcc = mvccByChannel
	merged.IsTopkReduce = topkReduce
	merged.IsRecallEvaluation = recallEvaluation
	return merged, nil
}

// ReduceAdvancedSearchResults packs the results of an advanced search into a
// single SearchResults message without reducing them: every input becomes one
// SubSearchResults entry (ReqIndex is its position in results), and the actual
// reduction is deferred to the proxy.
//
// The returned message carries the union of the channel MVCC timestamps, the OR
// of the IsTopkReduce flags, the NumQueries of the last input, and the
// aggregated request cost of all inputs.
func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
	_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceAdvancedSearchResults")
	defer span.End()

	merged := &internalpb.SearchResults{
		IsAdvanced:   true,
		ChannelsMvcc: make(map[string]uint64),
	}
	var totalRelatedDataSize int64
	costs := make([]*internalpb.CostAggregation, 0, len(results))

	for reqIndex, r := range results {
		merged.IsTopkReduce = merged.IsTopkReduce || r.GetIsTopkReduce()
		totalRelatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize()
		for channel, ts := range r.GetChannelsMvcc() {
			merged.ChannelsMvcc[channel] = ts
		}
		merged.NumQueries = r.GetNumQueries()

		// Sub results are only appended here; splitting and reducing them is
		// left to the proxy.
		merged.SubResults = append(merged.SubResults, &internalpb.SubSearchResults{
			MetricType:     r.GetMetricType(),
			NumQueries:     r.GetNumQueries(),
			TopK:           r.GetTopK(),
			SlicedBlob:     r.GetSlicedBlob(),
			SlicedNumCount: r.GetSlicedNumCount(),
			SlicedOffset:   r.GetSlicedOffset(),
			ReqIndex:       int64(reqIndex),
		})

		// The delegator no longer loads sealed segments when the streaming node
		// is enabled and yields no cost metrics without growing segments, so the
		// cost of every worker result is always collected
		// (EnableWorkerSQCostMetrics is deprecated).
		costs = append(costs, r.GetCostAggregation())
	}

	merged.CostAggregation = mergeRequestCost(costs)
	if merged.CostAggregation == nil {
		merged.CostAggregation = &internalpb.CostAggregation{}
	}
	merged.CostAggregation.TotalRelatedDataSize = totalRelatedDataSize
	return merged, nil
}

// SelectSearchResultData picks, for query qi, which of the merge ways in
// dataArray should supply the next hit.
//
// offsets[i] is the number of hits already consumed from way i for this query
// and resultOffsets[i][qi] is where the hits of query qi start inside way i.
// The way whose next unconsumed hit has the highest score wins; on equal scores
// typeutil.ComparePK on the two primary keys decides whether the later way
// replaces the current choice. It returns the index of the chosen way, or -1
// when every way is exhausted for this query.
func SelectSearchResultData(dataArray []*schemapb.SearchResultData, resultOffsets [][]int64, offsets []int64, qi int64) int {
	best := -1
	bestScore := -float32(math.MaxFloat32)
	bestIdx := int64(-1)

	for way, consumed := range offsets { // one entry per way to merge
		data := dataArray[way]
		if consumed >= data.Topks[qi] {
			// This way has no hits left for the query.
			continue
		}

		pos := resultOffsets[way][qi] + consumed
		score := data.Scores[pos]

		switch {
		case score > bestScore:
			best, bestScore, bestIdx = way, score, pos
		case score != bestScore:
			// Strictly worse (or not comparable): keep the current choice.
		case best == -1:
			// A bad case happens where knowhere returns distance == +/-maxFloat32
			// by mistake.
			log.Warn("a bad distance is found, something is wrong here!", zap.Float32("score", score))
		case typeutil.ComparePK(
			typeutil.GetPK(data.GetIds(), pos),
			typeutil.GetPK(dataArray[best].GetIds(), bestIdx)):
			best, bestScore, bestIdx = way, score, pos
		}
	}
	return best
}

// DecodeSearchResults unmarshals the SlicedBlob of every entry in searchResults
// into a schemapb.SearchResultData.
//
// Entries that are nil or have a nil SlicedBlob are skipped, so the returned
// slice may be shorter than the input; it is never nil on success. The first
// unmarshal failure aborts decoding and is returned as-is.
func DecodeSearchResults(ctx context.Context, searchResults []*internalpb.SearchResults) ([]*schemapb.SearchResultData, error) {
	_, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "DecodeSearchResults")
	defer sp.End()

	decoded := make([]*schemapb.SearchResultData, 0, len(searchResults))
	for _, partial := range searchResults {
		// The getter is nil-safe, so a nil entry is skipped instead of panicking.
		blob := partial.GetSlicedBlob()
		if blob == nil {
			continue
		}

		data := &schemapb.SearchResultData{}
		if err := proto.Unmarshal(blob, data); err != nil {
			return nil, err
		}
		decoded = append(decoded, data)
	}
	return decoded, nil
}

// EncodeSearchResultData wraps searchResultData into an internalpb.SearchResults
// with a success status and the given nq, topk and metric type.
//
// The data is always marshaled, and a marshal failure is returned as the error.
// The serialized bytes are attached as SlicedBlob only when the data holds at
// least one ID; otherwise SlicedBlob stays nil.
func EncodeSearchResultData(ctx context.Context, searchResultData *schemapb.SearchResultData,
	nq int64, topk int64, metricType string,
) (searchResults *internalpb.SearchResults, err error) {
	_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "EncodeSearchResultData")
	defer span.End()

	blob, err := proto.Marshal(searchResultData)
	if err != nil {
		return nil, err
	}

	searchResults = &internalpb.SearchResults{
		Status:     merr.Success(),
		NumQueries: nq,
		TopK:       topk,
		MetricType: metricType,
	}

	hasIDs := searchResultData != nil &&
		searchResultData.Ids != nil &&
		typeutil.GetSizeOfIDs(searchResultData.Ids) != 0
	if hasIDs {
		searchResults.SlicedBlob = blob
	}
	return searchResults, nil
}

// MergeInternalRetrieveResult merges several internalpb.RetrieveResults into one,
// de-duplicating rows by primary key.
//
// A single input is returned unchanged. Otherwise AllRetrieveCount and the
// related data size are summed over all inputs, inputs without field data or
// IDs are skipped, and rows are taken one at a time from the input chosen by
// typeutil.SelectMinPKWithTimestamp. When a primary key shows up again with a
// larger non-zero timestamp, the field data kept for it is replaced by the
// newer row.
//
// An error is returned when a timestamped wrapper cannot be built for an input
// or when the accumulated output size exceeds QuotaConfig.MaxOutputSize.
func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, param *mergeParam) (*internalpb.RetrieveResults, error) {
	log := log.Ctx(ctx)
	log.Debug("mergeInternelRetrieveResults",
		zap.Int64("limit", param.limit),
		zap.Int("resultNum", len(retrieveResults)),
	)
	// Nothing to merge: hand back the only input as it is.
	if len(retrieveResults) == 1 {
		return retrieveResults[0], nil
	}
	var (
		ret = &internalpb.RetrieveResults{
			Status: merr.Success(),
			Ids:    &schemapb.IDs{},
		}
		skipDupCnt int64 // number of rows skipped because their primary key was already seen
		loopEnd    int   // upper bound on the number of rows to emit
	)

	_, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeInternalRetrieveResult")
	defer sp.End()

	validRetrieveResults := []*TimestampedRetrieveResult[*internalpb.RetrieveResults]{}
	relatedDataSize := int64(0)
	hasMoreResult := false
	for _, r := range retrieveResults {
		// Counters are accumulated for every input, including the ones skipped
		// below; the generated getters are nil-safe.
		ret.AllRetrieveCount += r.GetAllRetrieveCount()
		relatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize()
		size := typeutil.GetSizeOfIDs(r.GetIds())
		if r == nil || len(r.GetFieldsData()) == 0 || size == 0 {
			continue
		}
		tr, err := NewTimestampedRetrieveResult(r)
		if err != nil {
			return nil, err
		}
		validRetrieveResults = append(validRetrieveResults, tr)
		loopEnd += size
		hasMoreResult = hasMoreResult || r.GetHasMoreResult()
	}
	ret.HasMoreResult = hasMoreResult

	if len(validRetrieveResults) == 0 {
		return ret, nil
	}

	// With an explicit limit (and a reduce type that honors it) emit at most
	// `limit` rows instead of every available row.
	if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) {
		loopEnd = int(param.limit)
	}

	ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].Result.GetFieldsData(), int64(loopEnd))
	// idTsMap remembers, for every emitted primary key, the timestamp of the row kept for it.
	idTsMap := make(map[interface{}]int64)
	// cursors[i] is the index of the next unread row of validRetrieveResults[i].
	cursors := make([]int64, len(validRetrieveResults))

	var retSize int64
	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
	// j counts emitted (distinct) rows only; duplicates advance a cursor but not j.
	for j := 0; j < loopEnd; {
		sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors)
		if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) {
			break
		}

		pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
		ts := validRetrieveResults[sel].Timestamps[cursors[sel]]
		if _, ok := idTsMap[pk]; !ok {
			typeutil.AppendPKs(ret.Ids, pk)
			retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel])
			idTsMap[pk] = ts
			j++
		} else {
			// primary keys duplicate
			skipDupCnt++
			if ts != 0 && ts > idTsMap[pk] {
				idTsMap[pk] = ts
				// NOTE(review): presumably DeleteFieldData drops the most recently
				// appended row, which relies on duplicates being selected back to
				// back - confirm. retSize is not reduced for the dropped row.
				typeutil.DeleteFieldData(ret.FieldsData)
				retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel])
			}
		}

		// limit retrieve result to avoid oom
		if retSize > maxOutputSize {
			return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize)
		}

		cursors[sel]++
	}

	if skipDupCnt > 0 {
		log.Debug("skip duplicated query result while reducing internal.RetrieveResults", zap.Int64("dupCount", skipDupCnt))
	}

	// The predicate is always true, so the cost of every input (valid or not) is collected.
	requestCosts := lo.FilterMap(retrieveResults, func(result *internalpb.RetrieveResults, _ int) (*internalpb.CostAggregation, bool) {
		// delegator node won't be used to load sealed segment if stream node is enabled
		// and if growing segment doesn't exists, delegator won't produce any cost metrics
		// so we deprecate the EnableWorkerSQCostMetrics param
		return result.GetCostAggregation(), true
	})
	ret.CostAggregation = mergeRequestCost(requestCosts)
	if ret.CostAggregation == nil {
		ret.CostAggregation = &internalpb.CostAggregation{}
	}
	ret.CostAggregation.TotalRelatedDataSize = relatedDataSize
	return ret, nil
}

// getTS returns the timestamp stored at row idx of the first field in i whose
// field ID is common.TimeStampField.
//
// It returns 0 when i is nil, when no such field exists, when that field
// carries no long data, or when idx is outside the stored data. Nil-safe
// getters and the bounds check replace what used to be panics in those cases.
func getTS(i *internalpb.RetrieveResults, idx int64) uint64 {
	for _, fieldData := range i.GetFieldsData() {
		if fieldData.GetFieldId() != common.TimeStampField {
			continue
		}
		timestamps := fieldData.GetScalars().GetLongData().GetData()
		if idx < 0 || idx >= int64(len(timestamps)) {
			return 0
		}
		return uint64(timestamps[idx])
	}
	return 0
}

// MergeSegcoreRetrieveResults merges per-segment segcore retrieve results into a
// single result, de-duplicating rows by primary key.
//
// retrieveResults and segments are indexed in parallel: segments[i] is used for
// retrieveResults[i]. Inputs without offsets or IDs are skipped, but their
// AllRetrieveCount is still summed. Rows are taken one at a time from the input
// chosen by typeutil.SelectMinPKWithTimestamp; when a primary key shows up again
// with a larger non-zero timestamp, the newer row replaces the one selected
// before. If param.limit is set and the reduce type honors it, at most that many
// distinct rows are selected.
//
// Field data is filled in after selection:
//   - if plan.IsIgnoreNonPk() is false, it is copied from the inputs;
//   - otherwise it is fetched from the segments with RetrieveByOffsets, one task
//     per contributing segment, submitted to the pool returned by GetSQPool().
//
// An error is returned when a timestamped wrapper cannot be built, when a
// segment retrieval fails, or when the accumulated output size exceeds
// QuotaConfig.MaxOutputSize.
func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, param *mergeParam, segments []Segment, plan *RetrievePlan, manager *Manager) (*segcorepb.RetrieveResults, error) {
	ctx, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults")
	defer span.End()

	log := log.Ctx(ctx)
	log.Debug("mergeSegcoreRetrieveResults",
		zap.Int64("limit", param.limit),
		zap.Int("resultNum", len(retrieveResults)),
	)
	var (
		ret = &segcorepb.RetrieveResults{
			Ids: &schemapb.IDs{},
		}

		skipDupCnt int64 // number of rows skipped because their primary key was already seen
		loopEnd    int   // total number of rows available in the valid inputs
	)

	validRetrieveResults := []*TimestampedRetrieveResult[*segcorepb.RetrieveResults]{}
	validSegments := make([]Segment, 0, len(segments))
	hasMoreResult := false
	for i, r := range retrieveResults {
		size := typeutil.GetSizeOfIDs(r.GetIds())
		ret.AllRetrieveCount += r.GetAllRetrieveCount()
		if r == nil || len(r.GetOffset()) == 0 || size == 0 {
			log.Debug("filter out invalid retrieve result")
			continue
		}
		tr, err := NewTimestampedRetrieveResult(r)
		if err != nil {
			return nil, err
		}
		validRetrieveResults = append(validRetrieveResults, tr)
		// validSegments stays index-aligned with validRetrieveResults; it is only
		// needed when field data has to be fetched from the segments later on.
		if plan.IsIgnoreNonPk() {
			validSegments = append(validSegments, segments[i])
		}
		loopEnd += size
		hasMoreResult = r.GetHasMoreResult() || hasMoreResult
	}
	ret.HasMoreResult = hasMoreResult

	if len(validRetrieveResults) == 0 {
		return ret, nil
	}

	// limit == -1 means "no limit".
	limit := -1
	if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) {
		limit = int(param.limit)
	}

	ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].Result.GetFieldsData(), int64(loopEnd))
	cursors := make([]int64, len(validRetrieveResults))

	// idTsMap gets exactly one entry per selected primary key, so it never holds
	// more than loopEnd entries, nor more than limit entries when a limit applies.
	// Size the hint from that bound: it is never negative and does not
	// over-allocate for large limits.
	idTsMapHint := loopEnd
	if limit >= 0 && limit < idTsMapHint {
		idTsMapHint = limit
	}
	idTsMap := make(map[any]int64, idTsMapHint)

	var availableCount int
	var retSize int64
	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()

	type selection struct {
		batchIndex  int   // index of validate retrieve results
		resultIndex int64 // index of selection in selected result item
		offset      int64 // offset of the result
	}

	var selections []selection

	for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ {
		sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors)
		if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) {
			break
		}

		pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
		ts := validRetrieveResults[sel].Timestamps[cursors[sel]]
		if _, ok := idTsMap[pk]; !ok {
			typeutil.AppendPKs(ret.Ids, pk)
			selections = append(selections, selection{
				batchIndex:  sel,
				resultIndex: cursors[sel],
				offset:      validRetrieveResults[sel].Result.GetOffset()[cursors[sel]],
			})
			idTsMap[pk] = ts
			availableCount++
		} else {
			// primary keys duplicate
			skipDupCnt++
			if ts != 0 && ts > idTsMap[pk] {
				idTsMap[pk] = ts
				// Find the selection made earlier for this primary key, searching
				// from the most recent one, and point it at the newer row.
				idx := len(selections) - 1
				for ; idx >= 0; idx-- {
					prev := selections[idx]
					pkValue := typeutil.GetPK(validRetrieveResults[prev.batchIndex].GetIds(), prev.resultIndex)
					if pk == pkValue {
						break
					}
				}
				if idx >= 0 {
					selections[idx] = selection{
						batchIndex:  sel,
						resultIndex: cursors[sel],
						offset:      validRetrieveResults[sel].Result.GetOffset()[cursors[sel]],
					}
				}
			}
		}

		cursors[sel]++
	}

	if skipDupCnt > 0 {
		log.Debug("skip duplicated query result while reducing segcore.RetrieveResults", zap.Int64("dupCount", skipDupCnt))
	}

	if !plan.IsIgnoreNonPk() {
		// The target entries were already retrieved. Field data is appended here
		// rather than right after AppendPKs so the `!plan.IsIgnoreNonPk()`
		// condition is evaluated once instead of once per row.
		_, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
		defer span2.End()
		ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].Result.GetFieldsData(), int64(len(selections)))
		for _, s := range selections {
			// cannot use `cursors[sel]` directly, since some of them may be skipped.
			retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[s.batchIndex].Result.GetFieldsData(), s.resultIndex)

			// limit retrieve result to avoid oom
			if retSize > maxOutputSize {
				return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize)
			}
		}
	} else {
		// The target entries were not retrieved yet: fetch them from the segments
		// by offset, one task per contributing segment.
		ctx, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-RetrieveByOffsets-AppendFieldData")
		defer span2.End()
		segmentResults := make([]*segcorepb.RetrieveResults, len(validRetrieveResults))
		groups := lo.GroupBy(selections, func(sel selection) int {
			return sel.batchIndex
		})
		futures := make([]*conc.Future[any], 0, len(groups))
		for i, group := range groups {
			idx, theOffsets := i, lo.Map(group, func(sel selection, _ int) int64 { return sel.offset })
			future := GetSQPool().Submit(func() (any, error) {
				var r *segcorepb.RetrieveResults
				err := doOnSegment(ctx, manager, validSegments[idx], func(ctx context.Context, segment Segment) error {
					var err error
					r, err = segment.RetrieveByOffsets(ctx, &segcore.RetrievePlanWithOffsets{
						RetrievePlan: plan,
						Offsets:      theOffsets,
					})
					return err
				})
				if err != nil {
					return nil, err
				}
				// Each task writes a distinct index of segmentResults.
				segmentResults[idx] = r
				return nil, nil
			})
			futures = append(futures, future)
		}
		// Must be BlockOnAll operation here.
		// If we perform a fast fail here, the cgo struct like `plan` will be used after free, unsafe memory access happens.
		if err := conc.BlockOnAll(futures...); err != nil {
			return nil, err
		}

		// Use the first segment result that carries field data as the layout template.
		for _, r := range segmentResults {
			if len(r.GetFieldsData()) != 0 {
				ret.FieldsData = typeutil.PrepareResultFieldData(r.GetFieldsData(), int64(len(selections)))
				break
			}
		}

		_, span3 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
		defer span3.End()
		// retrieve result is compacted, use 0,1,2...end
		segmentResOffset := make([]int64, len(segmentResults))
		for _, s := range selections {
			retSize += typeutil.AppendFieldData(ret.FieldsData, segmentResults[s.batchIndex].GetFieldsData(), segmentResOffset[s.batchIndex])
			segmentResOffset[s.batchIndex]++
			// limit retrieve result to avoid oom
			if retSize > maxOutputSize {
				return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize)
			}
		}
	}

	return ret, nil
}

// mergeInternalRetrieveResultsAndFillIfEmpty merges retrieveResults with
// MergeInternalRetrieveResult and then calls typeutil2.FillRetrieveResultIfEmpty
// on the merged result with param.outputFieldsId and param.schema.
//
// A merge error is returned unchanged. A fill error is wrapped with %w, so the
// message text is the same as before while the underlying error stays reachable
// through errors.Is / errors.As.
func mergeInternalRetrieveResultsAndFillIfEmpty(
	ctx context.Context,
	retrieveResults []*internalpb.RetrieveResults,
	param *mergeParam,
) (*internalpb.RetrieveResults, error) {
	mergedResult, err := MergeInternalRetrieveResult(ctx, retrieveResults, param)
	if err != nil {
		return nil, err
	}

	if err := typeutil2.FillRetrieveResultIfEmpty(typeutil2.NewInternalResult(mergedResult), param.outputFieldsId, param.schema); err != nil {
		return nil, fmt.Errorf("failed to fill internal retrieve results: %w", err)
	}

	return mergedResult, nil
}

// mergeSegcoreRetrieveResultsAndFillIfEmpty merges retrieveResults with
// MergeSegcoreRetrieveResults and then calls typeutil2.FillRetrieveResultIfEmpty
// on the merged result with param.outputFieldsId and param.schema.
//
// A merge error is returned unchanged. A fill error is wrapped with %w, so the
// message text is the same as before while the underlying error stays reachable
// through errors.Is / errors.As.
func mergeSegcoreRetrieveResultsAndFillIfEmpty(
	ctx context.Context,
	retrieveResults []*segcorepb.RetrieveResults,
	param *mergeParam,
	segments []Segment,
	plan *RetrievePlan,
	manager *Manager,
) (*segcorepb.RetrieveResults, error) {
	mergedResult, err := MergeSegcoreRetrieveResults(ctx, retrieveResults, param, segments, plan, manager)
	if err != nil {
		return nil, err
	}

	if err := typeutil2.FillRetrieveResultIfEmpty(typeutil2.NewSegcoreResults(mergedResult), param.outputFieldsId, param.schema); err != nil {
		return nil, fmt.Errorf("failed to fill segcore retrieve results: %w", err)
	}

	return mergedResult, nil
}
